{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compress and Evaluate Reasoning Large Language Models"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/reasoning_llm.ipynb\">\n",
    "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
    "</a>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "| Component | Details |\n",
    "|-----------|---------|\n",
    "| **Goal** | Showcase a standard workflow for optimizing and evaluating a reasoning Large Language Model |\n",
    "| **Model** | [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B) |\n",
    "| **Dataset** | [zwhe99/DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |\n",
    "| **Device** | 1 x H100 (80GB) |\n",
    "| **Optimization Algorithms** | quantizer(hqq), compiler(torch_compile) |\n",
    "| **Evaluation Metrics** | `total_time`, `perplexity`, `throughput`, `energy_consumed` |"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting Started\n",
    "\n",
    "To install the required dependencies, you can run the following command:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install pruna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For more information about how to install Pruna, please refer to the [Installation](https://docs.pruna.ai/en/stable/setup/install.html) page.\n",
    "\n",
    "Next, we select the best available device. For this tutorial, we recommend using a GPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load the Model\n",
    "\n",
    "First, we load the original model and tokenizer with the `pipeline` API of the `transformers` library. We use one of the smaller Qwen3 models, [Qwen/Qwen3-1.7B](https://huggingface.co/Qwen/Qwen3-1.7B), as a starting point. Pruna works at least as well with larger models, so feel free to use a bigger version of Qwen3 or any other [reasoning model available on Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "model_name = \"Qwen/Qwen3-1.7B\"\n",
    "\n",
    "pipe = pipeline(\n",
    "    \"text-generation\",\n",
    "    model_name,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With the pipeline loaded, we can generate a response and split it into the reasoning steps (the text the model emits between `<think>` and `</think>`) and the final answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'role': 'user',\n",
       "  'content': 'Give me a short introduction to large language models.'},\n",
       " {'role': 'assistant',\n",
       "  'content': 'Large language models (LLMs) are AI systems designed to understand, generate, and interact with human language. They are trained on massive datasets of text, enabling them to grasp complex patterns and produce coherent, context-aware responses. These models, often based on transformer architecture, excel in tasks like translation, writing, and answering questions. While they offer remarkable capabilities, they also face challenges such as data bias and the need for continuous refinement. LLMs are revolutionizing industries by enhancing productivity and innovation in areas like customer service, content creation, and research.',\n",
       "  'reasoning_content': 'Okay, the user wants a short introduction to large language models. Let me start by defining what they are. Large language models (LLMs) are AI systems trained on vast amounts of text data. I should mention their key features like natural language understanding and generation.\\n\\nWait, I need to make sure it\\'s concise. Maybe start with a definition, then talk about their training, the components like transformer architecture, and their applications. Also, mention that they\\'re used in various fields like customer service, content creation, etc. Oh, and maybe touch on their limitations, like data bias or lack of real-world knowledge. But since it\\'s a short intro, maybe keep it positive and highlight the benefits first.\\n\\nLet me check if I\\'m missing anything. The user might be a student or someone new to AI. They need a clear, straightforward explanation without too much jargon. Make sure to explain key terms like \"training data\" and \"transformer architecture\" in simple terms. Avoid technical details that might confuse them. Alright, structure it as a brief overview with key points.'}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import copy\n",
    "import re\n",
    "\n",
    "\n",
    "# Helper function to parse the thinking content\n",
    "def parse_thinking_content(messages):  # noqa: D103\n",
    "    messages = copy.deepcopy(messages)\n",
    "    for message in messages:\n",
    "        if message[\"role\"] == \"assistant\" and (\n",
    "            m := re.match(r\"<think>\\n(.+)</think>\\n\\n\", message[\"content\"], flags=re.DOTALL)\n",
    "        ):\n",
    "            message[\"content\"] = message[\"content\"][len(m.group(0)) :]\n",
    "            if thinking_content := m.group(1).strip():\n",
    "                message[\"reasoning_content\"] = thinking_content\n",
    "    return messages\n",
    "\n",
    "\n",
    "# Run the model\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Give me a short introduction to large language models.\",\n",
    "    },\n",
    "]\n",
    "messages = pipe(messages, max_new_tokens=32768)[0][\"generated_text\"]\n",
    "\n",
    "parse_thinking_content(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Define the SmashConfig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that our base model is loaded and tested, we can specify the `SmashConfig` to customize the optimizations applied during smashing.\n",
    "\n",
    "Not every optimization algorithm works with every model. You can learn about the requirements and compatibility in the [Algorithms Overview](https://docs.pruna.ai/en/stable/compression.html).\n",
    "\n",
    "In this example, we enable [hqq](https://docs.pruna.ai/en/stable/compression.html#hqq) quantization to reduce the memory footprint of the model and [torch_compile](https://docs.pruna.ai/en/stable/compression.html#torch-compile) compilation to speed up inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import SmashConfig\n",
    "\n",
    "smash_config = SmashConfig(\n",
    "    {\n",
    "        \"hqq\": {\"weight_bits\": 8, \"compute_dtype\": \"torch.bfloat16\"},\n",
    "        \"torch_compile\": {\"fullgraph\": True, \"dynamic\": True}\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Smash the Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have our `SmashConfig` defined, it’s time to apply it to our base model. We’ll call the `smash` function with the base model and our `SmashConfig`.\n",
    "\n",
    "Ready to smash? This operation typically takes around 20 seconds, depending on the configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna import smash\n",
    "\n",
    "# Keep an untouched copy of the base model on the CPU for the comparison in section 4\n",
    "copy_model = copy.deepcopy(pipe.model).to(\"cpu\")\n",
    "smashed_model = smash(\n",
    "    model=pipe.model,\n",
    "    smash_config=smash_config,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Great! Now we have our optimized smashed model. Let's check how it works by running some inference.\n",
    "\n",
    "Note that with `torch_compile`, the first inference call also serves as a warmup, so you can expect it to take longer than subsequent calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'role': 'user',\n",
       "  'content': 'Give me a short introduction to large language models.'},\n",
       " {'role': 'assistant',\n",
       "  'content': \"Large language models (LLMs) are advanced AI systems designed to understand and generate human-like text. They learn from vast amounts of data using deep learning techniques, enabling them to produce coherent and contextually relevant responses. These models excel in tasks like language translation, content creation, and customer service chatbots. While they're powerful, they're not infallible and rely on data quality. Their integration into daily life has transformed how we interact with technology, making tasks faster and more efficient.\",\n",
       "  'reasoning_content': \"Okay, the user wants a short introduction to large language models. Let me start by defining what they are. Large language models are AI systems that can understand and generate human-like text. I should mention their training with vast amounts of data and their use in various applications like chatbots and content creation.\\n\\nWait, maybe I should break it down into key points: what they are, how they work, and their applications. Keep it concise. Also, the user might be a beginner, so avoid technical jargon. Maybe mention neural networks and deep learning. Oh, and the difference between models like GPT and others. But since it's a short intro, maybe just a few examples. \\n\\nI need to ensure the explanation is clear but not too detailed. Maybe start with a simple definition, then talk about their training, then their uses. Also, note that they're not just data processors but have some level of understanding. Make sure it's friendly and easy to grasp. Let me check the key points again: definition, training, applications, and maybe a sentence on their limitations. But since it's a short intro, maybe keep it to the basics. Alright, time to put it all together in a few sentences.\"}]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Reuse the pipeline and the `parse_thinking_content` helper defined above\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": \"Give me a short introduction to large language models.\",\n",
    "    },\n",
    "]\n",
    "messages = pipe(messages, max_new_tokens=32768)[0][\"generated_text\"]\n",
    "parse_thinking_content(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see, the model still generates a similar response with a thinking process.\n",
    "\n",
    "If you notice a significant difference, it might be due to the model, the configuration, the hardware, etc. As optimization can be non-deterministic, we encourage you to retry the optimization process or try out different configurations and models to find the best fit for your use case. Feel free to reach out to us on [Discord](https://discord.gg/JFQmtFKCjd) if you have any questions or feedback."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Evaluate the Smashed Model\n",
    "\n",
    "Now that our smashed model is working, we can measure how much the optimization improved it. To do so, we run an evaluation with the `EvaluationAgent`, using [zwhe99/DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) as our reasoning [custom dataset](https://docs.pruna.ai/en/v0.2.9/docs_pruna/user_manual/evaluate.html#prunadatamodule). We track the metrics `total_time`, `perplexity`, `throughput` and `energy_consumed`.\n",
    "\n",
    "A complete list of the available metrics can be found in [Evaluation](https://docs.pruna.ai/en/stable/reference/evaluation.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "from pruna import PrunaModel\n",
    "from pruna.data.pruna_datamodule import PrunaDataModule\n",
    "from pruna.data.utils import split_train_into_train_val_test\n",
    "from pruna.evaluation.evaluation_agent import EvaluationAgent\n",
    "from pruna.evaluation.metrics import (\n",
    "    EnergyConsumedMetric,\n",
    "    ThroughputMetric,\n",
    "    TorchMetricWrapper,\n",
    "    TotalTimeMetric,\n",
    ")\n",
    "from pruna.evaluation.task import Task\n",
    "\n",
    "# Define the metrics. Increase the number of iterations\n",
    "# and warmup iterations to get a more accurate result.\n",
    "metrics = [\n",
    "    TotalTimeMetric(n_iterations=50, n_warmup_iterations=5),\n",
    "    ThroughputMetric(n_iterations=50, n_warmup_iterations=5),\n",
    "    TorchMetricWrapper(\"perplexity\", call_type=\"single\"),\n",
    "    EnergyConsumedMetric(n_iterations=50, n_warmup_iterations=5),\n",
    "]\n",
    "\n",
    "# Load the dataset and split it into train, validation and test\n",
    "train_ds = load_dataset(\"zwhe99/DeepMath-103K\", split=\"train\")\n",
    "train_ds = train_ds.rename_column(\n",
    "    \"question\", \"text\"\n",
    ")  # Rename the column to match the `text_generation_collate` function\n",
    "train_ds, val_ds, test_ds = split_train_into_train_val_test(train_ds, seed=42)\n",
    "\n",
    "# (Optional) Use the eos_token as the pad_token\n",
    "pipe.tokenizer.pad_token = pipe.tokenizer.eos_token\n",
    "\n",
    "# Create the data module. Increase `max_seq_len` to match the\n",
    "# `max_new_tokens` of the model for a more accurate evaluation.\n",
    "datamodule = PrunaDataModule.from_datasets(\n",
    "    datasets=(train_ds, val_ds, test_ds),\n",
    "    collate_fn=\"text_generation_collate\",\n",
    "    tokenizer=pipe.tokenizer,\n",
    "    collate_fn_args={\"max_seq_len\": 512},\n",
    "    dataloader_args={\"batch_size\": 16, \"num_workers\": 4},\n",
    ")\n",
    "datamodule.limit_datasets(100)\n",
    "\n",
    "# Define the task and the evaluation agent\n",
    "task = Task(metrics, datamodule=datamodule, device=device)\n",
    "eval_agent = EvaluationAgent(task)\n",
    "\n",
    "# (Optional) Define specific inference arguments for benchmarking.\n",
    "inference_args = {\n",
    "    \"max_new_tokens\": 512,  # Increase `max_new_tokens` for a more accurate evaluation.\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pruna.engine.utils import move_to_device\n",
    "\n",
    "# Evaluate the smashed model and offload it to CPU\n",
    "move_to_device(smashed_model, device)\n",
    "smashed_model.inference_handler.model_args.update(inference_args)\n",
    "smashed_model_results = eval_agent.evaluate(smashed_model)\n",
    "move_to_device(smashed_model, \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the base model and offload it to CPU\n",
    "base_pipe = PrunaModel(model=copy_model)\n",
    "move_to_device(base_pipe, device)\n",
    "base_pipe.inference_handler.model_args.update(inference_args)\n",
    "base_model_results = eval_agent.evaluate(base_pipe)\n",
    "move_to_device(base_pipe, \"cpu\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can see the results of the evaluation and compare the performance of the original and the optimized model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "| Metric | Base Model | Smashed Model | Improvement % |\n",
       "|-----|-----|-----|-----|\n",
       "| perplexity | 3.3330  | 2.8230  | 15.30% |\n",
       "| total_time | 42390.9036 ms | 6869.6069 ms | 83.79% |\n",
       "| throughput | 0.0189 num_iterations/ms | 0.1165 num_iterations/ms | 517.08% |\n",
       "| energy_consumed | 0.0059 kWh | 0.0011 kWh | 81.92% |"
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from IPython.display import Markdown, display  # noqa\n",
    "\n",
    "\n",
    "def make_comparison_table(base_model_results, smashed_model_results):  # noqa\n",
    "    header = \"| Metric | Base Model | Smashed Model | Improvement % |\\n\"\n",
    "    header += \"|\" + \"-----|\" * 4 + \"\\n\"\n",
    "    rows = []\n",
    "\n",
    "    for base, smashed in zip(base_model_results, smashed_model_results):\n",
    "        base_result = base.result\n",
    "        smashed_result = smashed.result\n",
    "        if base.higher_is_better:\n",
    "            diff = ((smashed_result - base_result) / base_result) * 100\n",
    "        else:\n",
    "            diff = ((base_result - smashed_result) / base_result) * 100\n",
    "        row = f\"| {base.name} | {base_result:.4f} {base.metric_units or ''}\"\n",
    "        row += f\"| {smashed_result:.4f} {smashed.metric_units or ''} | {diff:.2f}% |\"\n",
    "        rows.append(row)\n",
    "    return header + \"\\n\".join(rows)\n",
    "\n",
    "\n",
    "display(Markdown(make_comparison_table(base_model_results, smashed_model_results)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected, we observe a significant improvement. The smashed model is about 6× faster (total time drops by 83.79%) and delivers about 6× the throughput (a 517.08% increase). Even better, perplexity did not degrade on this evaluation subset (2.8230 versus 3.3330; remember, lower perplexity means better results), and energy consumption fell by 81.92%. This really is the best-case scenario!\n",
    "\n",
    "With these results, we can save the optimized model to disk or share it with others:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the model to disk\n",
    "smashed_model.save_pretrained(\"Qwen3-1.7B-smashed\")\n",
    "# Load the model from disk\n",
    "# smashed_model = PrunaModel.from_pretrained(\"Qwen3-1.7B-smashed/\")\n",
    "\n",
    "# Push the model to the Hugging Face Hub\n",
    "# smashed_model.push_to_hub(\"PrunaAI/Qwen3-1.7B-smashed\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusions\n",
    "\n",
    "In this tutorial, we have seen how to optimize and evaluate a reasoning Large Language Model using Pruna. We used the `SmashConfig` to customize the optimizations applied during smashing and the `EvaluationAgent` to evaluate the optimized model.\n",
    "\n",
    "The results show that by combining quantization and compilation, we can significantly improve total time, throughput and energy consumption without degrading perplexity.\n",
    "\n",
    "Check out our other [tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html) for more examples of how to optimize and evaluate image and video generation models or LLMs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pruna0",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
